Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FSDP tests and checkpointing fixes #26180

Merged
merged 26 commits into from
Sep 20, 2023
Merged

FSDP tests and checkpointing fixes #26180

merged 26 commits into from
Sep 20, 2023

Conversation

pacman100
Copy link
Contributor

@pacman100 pacman100 commented Sep 15, 2023

What does this PR do?

  1. Fixes certain bugs with checkpointing when using FSDP
  2. Adds tests for FSDP integration in Trainer.
  3. Different combination runs to check resuming from checkpoints work as expected.

Below we will run the different combinations of FSDP SHARDING_STRATEGY and STATE_DICT_TYPE for the run_glue.py transformers example
Initial setup:

cd transformers
export CUDA_VISISBLE_DEVICES=0,1
export TASK_NAME=mrpc

a. FULL_SHARD + FULL_STATE_DICT

i. command to run:

torchrun --nnodes 1 --nproc-per-node 2 ./examples/pytorch/text-classification/run_glue.py --model_name_or_path bert-base-cased  --task_name $TASK_NAME  --do_train  --do_eval  --max_seq_length 128  --per_device_train_batch_size 16  --learning_rate 5e-5  --num_train_epochs 3  --output_dir /tmp/$TASK_NAME/ --overwrite_output_dir --lr_scheduler_type cosine --save_strategy "epoch" --evaluation_strategy "epoch" --logging_steps 1 --fsdp "full_shard auto_wrap"  --fsdp_transformer_layer_cls_to_wrap BertLayer --bf16

Kill the process after epoch 1. Run the above command with --resume_from_checkpoint as below:

torchrun --nnodes 1 --nproc-per-node 2 ./examples/pytorch/text-classification/run_glue.py --model_name_or_path bert-base-cased  --task_name $TASK_NAME  --do_train  --do_eval  --max_seq_length 128  --per_device_train_batch_size 16  --learning_rate 5e-5  --num_train_epochs 3  --output_dir /tmp/$TASK_NAME/ --overwrite_output_dir --lr_scheduler_type cosine --save_strategy "epoch" --evaluation_strategy "epoch" --logging_steps 1 --fsdp "full_shard auto_wrap"  --fsdp_transformer_layer_cls_to_wrap BertLayer --bf16  --resume_from_checkpoint /tmp/$TASK_NAME/checkpoint-115/

iii. Plots of loss and learning rate:
Screenshot 2023-09-15 at 2 07 43 PM

b. SHARD_GRAD_OP + FULL_STATE_DICT
Same as above but with the following cmd arg --fsdp "shard_grad_op auto_wrap"

Plots:
Screenshot 2023-09-15 at 2 09 29 PM

c. FULL_SHARD + SHARDED_STATE_DICT

i. Here, we will need to use the accelerate launcher as the option to choose SHARDED_STATE_DICT is currently available via accelerate config. Below is the config file fsdp_config.yaml:

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_sharding_strategy: 1
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_transformer_layer_cls_to_wrap: BertLayer
  fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

ii. command to run:

accelerate launch --config_file "fsdp_config.yaml" ./examples/pytorch/text-classification/run_glue.py --model_name_or_path bert-base-cased --task_name $TASK_NAME --do_train --do_eval --max_seq_length 128 --per_device_train_batch_size 16 --learning_rate 5e-5 --num_train_epochs 3 --output_dir /tmp/$TASK_NAME/ --overwrite_output_dir --lr_scheduler_type cosine --save_strategy "epoch" --evaluation_strategy "epoch" --logging_steps 1

Kill the process after epoch 1. Run the above command with --resume_from_checkpoint as below:

accelerate launch --config_file "fsdp_config.yaml" ./examples/pytorch/text-classification/run_glue.py --model_name_or_path bert-base-cased --task_name $TASK_NAME --do_eval --max_seq_length 128 --per_device_train_batch_size 16 --learning_rate 5e-5 --num_train_epochs 5 --output_dir /tmp/$TASK_NAME/ --overwrite_output_dir --lr_scheduler_type cosine --save_strategy "epoch" --evaluation_strategy "epoch" --logging_steps 1  --resume_from_checkpoint /tmp/$TASK_NAME/checkpoint-115/

iii. Plots:
Screenshot 2023-09-15 at 2 14 16 PM

d. SHARD_GRAD_OP + SHARDED_STATE_DICT
Just run the accelerate config command and choose SHARD_GRAD_OP Sharding strategy and get fsdp_config.yaml similar to the above case. The rest is the same.

Plots:
Screenshot 2023-09-15 at 2 16 02 PM

@pacman100 pacman100 marked this pull request as ready for review September 15, 2023 08:12
@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Sep 15, 2023

The documentation is not available anymore as the PR was closed or merged.

@jphme
Copy link
Contributor

jphme commented Sep 17, 2023

Just a short feedback, when trying to resume from a checkpoint with SHARDED_STATE_DICT (see #26186 for setup/details) with this PR, i get a Cuda OOM error, full stacktrace below.

Full Stacktrace
  File "/workspace/axolotl/scripts/finetune.py", line 287, in <module>
Traceback (most recent call last):
  File "/workspace/axolotl/scripts/finetune.py", line 287, in <module>
Traceback (most recent call last):
    fire.Fire(do_cli)
  File "/workspace/axolotl/scripts/finetune.py", line 287, in <module>
  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/fire/core.py", line 141, in Fire
    fire.Fire(do_cli)
      File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/fire/core.py", line 141, in Fire
fire.Fire(do_cli)
  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/fire/core.py", line 141, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/fire/core.py", line 475, in _Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
          File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/fire/core.py", line 475, in _Fire
component, remaining_args = _CallAndUpdateTrace(
component_trace = _Fire(component, args, parsed_flag_args, context, name)  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/fire/core.py", line 691, in _CallAndUpdateTrace

  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/fire/core.py", line 475, in _Fire
    component = fn(*varargs, **kwargs)
      File "/workspace/axolotl/scripts/finetune.py", line 283, in do_cli
component, remaining_args = _CallAndUpdateTrace(
      File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/fire/core.py", line 691, in _CallAndUpdateTrace
    component, remaining_args = _CallAndUpdateTrace(train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)

  File "/workspace/axolotl/src/axolotl/train.py", line 116, in train
  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/fire/core.py", line 691, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/workspace/axolotl/scripts/finetune.py", line 283, in do_cli
        trainer.train(resume_from_checkpoint=resume_from_checkpoint)component = fn(*varargs, **kwargs)

  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/transformers/trainer.py", line 1575, in train
  File "/workspace/axolotl/scripts/finetune.py", line 283, in do_cli
    train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
      File "/workspace/axolotl/src/axolotl/train.py", line 116, in train
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
  File "/workspace/axolotl/src/axolotl/train.py", line 116, in train
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
      File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/transformers/trainer.py", line 1575, in train
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/transformers/trainer.py", line 1575, in train
    return inner_training_loop(
        return inner_training_loop(  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/transformers/trainer.py", line 1876, in _inner_training_loop
return inner_training_loop(

  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/transformers/trainer.py", line 1876, in _inner_training_loop
  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/transformers/trainer.py", line 1876, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/transformers/trainer.py", line 2768, in training_step
    tr_loss_step = self.training_step(model, inputs)
tr_loss_step = self.training_step(model, inputs)
  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/transformers/trainer.py", line 2768, in training_step
  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/transformers/trainer.py", line 2768, in training_step
    self.accelerator.backward(loss)
  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/accelerate/accelerator.py", line 1963, in backward
    self.scaler.scale(loss).backward(**kwargs)
  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/torch/_tensor.py", line 487, in backward
        self.accelerator.backward(loss)
self.accelerator.backward(loss)
  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/accelerate/accelerator.py", line 1963, in backward
  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/accelerate/accelerator.py", line 1963, in backward
    torch.autograd.backward(
  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/torch/autograd/__init__.py", line 200, in backward
        self.scaler.scale(loss).backward(**kwargs)self.scaler.scale(loss).backward(**kwargs)

  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/torch/_tensor.py", line 487, in backward
  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/torch/_tensor.py", line 487, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/torch/autograd/function.py", line 274, in apply
        torch.autograd.backward(torch.autograd.backward(

  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/torch/autograd/__init__.py", line 200, in backward
  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/torch/autograd/__init__.py", line 200, in backward
            return user_fn(self, *args)Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward passVariable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/torch/utils/checkpoint.py", line 157, in backward
  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/torch/autograd/function.py", line 274, in apply
  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/torch/autograd/function.py", line 274, in apply
        return user_fn(self, *args)return user_fn(self, *args)

  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/torch/utils/checkpoint.py", line 157, in backward
  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/torch/utils/checkpoint.py", line 157, in backward
    torch.autograd.backward(outputs_with_grad, args_with_grad)
      File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/torch/autograd/__init__.py", line 200, in backward
    torch.autograd.backward(outputs_with_grad, args_with_grad)torch.autograd.backward(outputs_with_grad, args_with_grad)

  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/torch/autograd/__init__.py", line 200, in backward
  File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/torch/autograd/__init__.py", line 200, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
torch.cuda.OutOfMemoryError:         CUDA out of memory. Tried to allocate 172.00 MiB (GPU 3; 31.74 GiB total capacity; 30.59 GiB already allocated; 168.38 MiB free; 31.12 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONFVariable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward passVariable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass

@pacman100
Copy link
Contributor Author

pacman100 commented Sep 18, 2023

Hello @jphme, I do notice an increase in GPU memory consumption of about 600MB for above tests when resuming from checkpoint saved via SHARDED_STATE_DICT. However, that needs to be resolved by PyTorch team as it does not pertain to the integration. Could you raise an issue with PyTorch repo: https://github.com/pytorch/pytorch/issues regarding this?

@jphme
Copy link
Contributor

jphme commented Sep 18, 2023

Hello @jphme, I do notice an increase in GPU memory consumption of about 600MB for above tests when using SHARDED_STATE_DICT. However, that needs to be resolved by PyTorch team as it does not pertain to the integration. Could you raise an issue with PyTorch repo: https://github.com/pytorch/pytorch/issues regarding this?

Hi @pacman100 sure - but just to clarify: The training started (and ran until the checkpoint) without problems and its also possible to extract the model after the training with trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT") on the same instance/setup.

So you mean that specifically for restarting from a SHARDED_STATE_DICT checkpoint more VRAM is needed and we can do nothing about it?

This is quite dangerous as everyone tunes their runs so VRAM is maxed and that would mean that many runs can't be restarted from checkpoints...

EDIT: Ok I re-read your post - in my case the checkpoint was indeed created with the main branch and I only tried to restart with this PR; if the PR generally increases VRAM consumption that would explain it.

But then I don't understand whats exactly the Pytorch issue. And is there no way (with offloading) to avoid the increased VRAM consumption as everything besides checkpointing (training, model extraction) worked fine for me? (Sorry if i am a bit slow understanding, still new to FSDP/Torch - many thanks for your work on this!)

@pacman100
Copy link
Contributor Author

if the PR generally increases VRAM consumption that would explain it.

This PR doesn't increase VRAM consumption. Internally, it is calling the Torch utility here:

https://github.com/huggingface/accelerate/blob/a87c95da9e3b416fb10a0e7dac7d397c015c3ed5/src/accelerate/utils/fsdp_utils.py#L114-L130

and here:

https://github.com/huggingface/accelerate/blob/a87c95da9e3b416fb10a0e7dac7d397c015c3ed5/src/accelerate/utils/fsdp_utils.py#L178-L192

These are probably leading to the increased VRAM consumption.

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Left a few nits, mostly on the test 🤗

src/transformers/testing_utils.py Outdated Show resolved Hide resolved
tests/fsdp/test_fsdp.py Show resolved Hide resolved
tests/fsdp/test_fsdp.py Show resolved Hide resolved
src/transformers/modeling_utils.py Show resolved Hide resolved
src/transformers/trainer.py Show resolved Hide resolved
src/transformers/trainer.py Show resolved Hide resolved
tests/fsdp/test_fsdp.py Outdated Show resolved Hide resolved
tests/fsdp/test_fsdp.py Outdated Show resolved Hide resolved
tests/fsdp/test_fsdp.py Outdated Show resolved Hide resolved
tests/fsdp/test_fsdp.py Outdated Show resolved Hide resolved
pacman100 and others added 2 commits September 19, 2023 10:19
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
@jphme
Copy link
Contributor

jphme commented Sep 19, 2023

So just for further reference (because other people are starting to have the same issue and commented on my closed issue): Checkpoints are currently of no use with SHARDED_STATE_DICT if maxing out Vram during training, because you will run OOM when trying to continue, even if everything else (starting training, creating checkpoint, saving model at the end after converting to FULL_STATE_DICT) works fine.

Will try with torch nightly if I have the opportunity (there seems to be a new env that could help), unfortunately very busy currently.

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking forward to the fixes this brings!

@pacman100 pacman100 merged commit 382ba67 into main Sep 20, 2023
@pacman100 pacman100 deleted the smangrul/fsdp-tests branch September 20, 2023 04:56
hzhiyuan pushed a commit to hzhiyuan/transformers that referenced this pull request Sep 20, 2023
* add fsdp tests

* Update test_fsdp.py

* Update test_fsdp.py

* fixes

* checks

* Update trainer.py

* fix

* fixes for saving/resuming checkpoints

* fixes

* add tests and delete debug statements

* fixing tests

* Update test_fsdp.py

* fix tests

* fix tests

* minor nits

* fix code style and quality

* refactor and modularize test code

* reduce the time of tests

* reduce the test time

* fix test

* reduce test time

* reduce test time

* fix failing tests

* fix

* Apply suggestions from code review

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* resolve comments

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
parambharat pushed a commit to parambharat/transformers that referenced this pull request Sep 26, 2023
* add fsdp tests

* Update test_fsdp.py

* Update test_fsdp.py

* fixes

* checks

* Update trainer.py

* fix

* fixes for saving/resuming checkpoints

* fixes

* add tests and delete debug statements

* fixing tests

* Update test_fsdp.py

* fix tests

* fix tests

* minor nits

* fix code style and quality

* refactor and modularize test code

* reduce the time of tests

* reduce the test time

* fix test

* reduce test time

* reduce test time

* fix failing tests

* fix

* Apply suggestions from code review

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* resolve comments

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
@jmzeng
Copy link

jmzeng commented Sep 27, 2023

Hi, has this fix been merged into the new the new transformers v4.33.3?

@LysandreJik
Copy link
Member

Hey @jmzeng, it is not part of v4.33.3 but will be part of v4.34.0 which will be released early next week.

In the meantime, you can install from source:

pip install git+https://github.com/huggingface/transformers

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants